# load libraries
library(tidyverse)
library(here)
library(tidymodels)
library(vip)
library(patchwork)


# source
source("replication/00-funcs.R")

# read data: the analysis data set stored as an .rds file under
# replication/data/ (path resolved from the project root via here()).
clean_df <- here("replication", "data", "lasso-data.rds") %>%
  read_rds()



# dictionary: lookup table from raw variable names (`dirty`) to the
# human-readable labels (`clean`) joined onto the variable-importance output
# for plot axes. Built from a named character vector so each name sits next
# to its label.
dict <- tibble::enframe(
  c(
    race_bin              = "White/non-white",
    displace              = "Displaced (y/n)",
    man                   = "Sex (male)",
    age_num               = "Age (years)",
    hhsize                = "Household size",
    indrural              = "% rural population",
    lpop                  = "Population (logged)",
    discapital            = "Distance to capital",
    nbi                   = "Municipal poverty index",
    laltura               = "Altitude (logged)",
    homicidios            = "Municipal homicides (avg)",
    AUC_                  = "AUC presence (prop. of years)",
    FARC_                 = "FARC presence (prop. of years)",
    ELN_                  = "ELN presence (prop. of years)",
    desplazados_expulsion = "Number of people displaced (avg)",
    other_victim          = "Other victimization experiences",
    prop.score            = "Propensity score (distance)",
    lrdp_num              = "LRDP programming area",
    read                  = "literacy",
    ed                    = "Years of education",
    news                  = "News consumption",
    dis_land              = "Displaced for land",
    displace_force        = "Displaced by force",
    law_know              = "Knows Victim's Law",
    yrs_dis               = "Years since displacement",
    mtng_victim           = "Attends victims' meetings",
    risk_easy             = "Risk aversion",
    risk_ladder           = "Risk-seeking",
    income                = "Monthly income",
    homi_per_1k           = "Violence against social leaders"
  ),
  name = "dirty",
  value = "clean"
)



# who returns? ----------------------------------------------------------------------



# no missing: keep displaced respondents only, restrict to the return outcome
# plus candidate predictors, recode the outcome as a "no"/"yes" factor, and
# drop every row with any missing value among the selected columns.
disp_full <- clean_df %>%
  filter(displace == 1) %>%
  select(
    dis_return,
    man, age_num, hhsize, ed, income, news,
    homi_per_1k, displace_force, race_bin, yrs_dis, risk_ladder
  ) %>%
  mutate(dis_return = factor(if_else(dis_return == 1, "yes", "no"))) %>%
  drop_na()



# set up samples: 90% training / 10% test split (not stratified), seeded so
# the split is reproducible.
set.seed(1990)
data_split <- disp_full %>% initial_split(prop = 0.9)
data_train <- data_split %>% training()
data_test  <- data_split %>% testing()



# prep data: centre and scale every numeric non-outcome column. glmnet's own
# standardization is switched off in the model spec below, so this recipe
# step is where the predictors get scaled.
data_rec <- recipe(dis_return ~ ., data = data_train) %>%
  step_normalize(all_numeric(), -all_outcomes())

# Workflow holding only the recipe; the model is attached at tuning time.
wf <- add_recipe(workflow(), data_rec)

# Bootstrap resamples of the training set for choosing the penalty. This is
# the next random draw after the seeded split (no re-seed in between).
data_boot <- bootstraps(data_train)

# LASSO logistic regression: mixture = 1 (pure L1), penalty left to tune.
tune_spec <- set_engine(
  logistic_reg(penalty = tune(), mixture = 1),
  "glmnet",
  standardize = FALSE
)

# Regular grid of 300 candidate values over dials' default penalty() range.
lambda_grid <- penalty() %>% grid_regular(levels = 300)

# tune: evaluate every candidate penalty on every bootstrap resample, using
# the doParallel backend registered here; re-seed first for reproducibility.
doParallel::registerDoParallel()

set.seed(1990)
lasso_grid <- wf %>%
  add_model(tune_spec) %>%
  tune_grid(resamples = data_boot, grid = lambda_grid)



# results: diagnostic plot of mean resampled performance (+/- one standard
# error) against the penalty on a log10 axis, one facet per metric. The plot
# is displayed but not assigned to a variable.
ggplot(
  collect_metrics(lasso_grid),
  aes(x = penalty, y = mean, color = .metric)
) +
  geom_errorbar(
    aes(ymin = mean - std_err, ymax = mean + std_err),
    alpha = 0.5
  ) +
  geom_line(size = 1.5) +
  facet_wrap(vars(.metric), scales = "free", nrow = 2) +
  scale_x_log10() +
  theme(legend.position = "none")


# pick best model: take the penalty with the highest resampled ROC AUC and
# plug it into the workflow. `metric` is passed by name because recent tune
# releases (1.2.0+) place it after `...`, where a positional value is
# rejected.
best_params <- lasso_grid %>%
  select_best(metric = "roc_auc")

final_lasso <- finalize_workflow(
  wf %>% add_model(tune_spec),
  best_params
)


# coefficient estimates from Lasso: refit the finalized workflow on the
# training data, compute vip importance at the selected penalty, and plot it
# with negative bars for predictors whose Sign is "NEG". Readable labels come
# from `dict`.
p1 <- final_lasso %>%
  fit(data_train) %>%
  extract_fit_parsnip() %>%
  vi(lambda = best_params$penalty) %>%
  left_join(dict, by = c("Variable" = "dirty")) %>%
  mutate(
    Importance = if_else(Sign == "NEG", -Importance, Importance),
    # order bars by the signed importance computed on the line above
    clean = fct_reorder(clean, Importance),
    group = "Was respondent able to return home?"
  ) %>%
  ggplot(aes(x = Importance, y = clean, fill = Sign)) +
  geom_col() +
  # NOTE(review): fixed limits drop any bar whose |importance| exceeds 0.4
  scale_x_continuous(limits = c(-.4, .4)) +
  labs(y = NULL, x = "Variable importance (LASSO)") +
  scale_fill_brewer(palette = "Set1") +
  facet_wrap(vars(group)) +
  theme(legend.position = "none",
        strip.text = element_text(face = "bold")) +
  geom_vline(xintercept = 0, size = 1, color = "grey80")

p1

# Save p1 explicitly instead of relying on ggsave()'s last_plot() default.
ggsave(here("replication", "figures", "lasso-return-vip.pdf"),
       plot = p1, device = cairo_pdf)


# Out-of-sample evaluation: fit the finalized workflow on the training split
# and score it on the held-out test split. last_fit() runs once here, and the
# result feeds both the metrics table and the ROC curve.
return_final_fit <- last_fit(final_lasso, data_split)

auc <- collect_metrics(return_final_fit)


# ROC curve on the test-set predictions. dis_return was built with
# as.factor() on "no"/"yes", so "yes" is the second level and is the event.
p1auc <- return_final_fit %>%
  collect_predictions() %>%
  roc_curve(dis_return, .pred_yes, event_level = "second") %>%
  mutate(group = "Outcome: respondent able to return") %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_line(size = 1.5, color = "midnightblue") +
  geom_abline(lty = 2, alpha = 0.5, color = "gray50", size = 1.2) +
  # single annotation placed at (0.5, 0.25) reporting the test-set ROC AUC
  ggrepel::geom_label_repel(
    data = tibble(specificity = .5, sensitivity = .25),
    color = "midnightblue",
    label = glue::glue("Out-of-sample AUC: {round(auc$.estimate[auc$.metric == 'roc_auc'], 2)}")
  ) +
  facet_wrap(vars(group)) +
  theme(strip.text = element_text(face = "bold"))
p1auc

# Save p1auc explicitly instead of relying on ggsave()'s last_plot() default.
ggsave(here("replication", "figures", "lasso-return-auc.pdf"),
       plot = p1auc, device = cairo_pdf)


# who seeks restitution? -------------------------------------------------------------


# no missing: keep displaced respondents only, restrict to the restitution
# outcome plus candidate predictors, recode the outcome as a "no"/"yes"
# factor, and drop every row with any missing value among these columns.
disp_full <- clean_df %>%
  filter(displace == 1) %>%
  select(
    rest_yn,
    man, age_num, hhsize, ed, income, news,
    displace_force, law_know, indrural, homi_per_1k,
    race_bin, yrs_dis, mtng_victim, risk_ladder
  ) %>%
  mutate(rest_yn = factor(if_else(rest_yn == 1, "yes", "no"))) %>%
  drop_na()



# set up samples: train/test split stratified on the outcome, using rsample's
# default training proportion; seeded so the split is reproducible.
set.seed(1990)
data_split <- disp_full %>% initial_split(strata = rest_yn)
data_train <- data_split %>% training()
data_test  <- data_split %>% testing()



# prep data: centre and scale every numeric non-outcome column. glmnet's own
# standardization is switched off in the model spec below, so this recipe
# step is where the predictors get scaled.
data_rec <- recipe(rest_yn ~ ., data = data_train) %>%
  step_normalize(all_numeric(), -all_outcomes())

# Workflow holding only the recipe; the model is attached at tuning time.
wf <- add_recipe(workflow(), data_rec)

# Outcome-stratified bootstrap resamples of the training set for choosing the
# penalty. This is the next random draw after the seeded split.
data_boot <- bootstraps(data_train, strata = rest_yn)

# LASSO logistic regression: mixture = 1 (pure L1), penalty left to tune.
tune_spec <- set_engine(
  logistic_reg(penalty = tune(), mixture = 1),
  "glmnet",
  standardize = FALSE
)

# Regular grid of 300 candidate values over dials' default penalty() range.
lambda_grid <- penalty() %>% grid_regular(levels = 300)

# tune: evaluate every candidate penalty on every bootstrap resample, using
# the doParallel backend registered here; re-seed first for reproducibility.
doParallel::registerDoParallel()

set.seed(1990)
lasso_grid <- wf %>%
  add_model(tune_spec) %>%
  tune_grid(resamples = data_boot, grid = lambda_grid)



# results: diagnostic plot of mean resampled performance (+/- one standard
# error) against the penalty on a log10 axis, one facet per metric. The plot
# is displayed but not assigned to a variable.
ggplot(
  collect_metrics(lasso_grid),
  aes(x = penalty, y = mean, color = .metric)
) +
  geom_errorbar(
    aes(ymin = mean - std_err, ymax = mean + std_err),
    alpha = 0.5
  ) +
  geom_line(size = 1.5) +
  facet_wrap(vars(.metric), scales = "free", nrow = 2) +
  scale_x_log10() +
  theme(legend.position = "none")


# pick best model: take the penalty with the highest resampled ROC AUC and
# plug it into the workflow. `metric` is passed by name because recent tune
# releases (1.2.0+) place it after `...`, where a positional value is
# rejected.
best_params <- lasso_grid %>%
  select_best(metric = "roc_auc")

final_lasso <- finalize_workflow(
  wf %>% add_model(tune_spec),
  best_params
)


# coefficient estimates from Lasso: refit the finalized workflow on the
# training data, compute vip importance at the selected penalty, and plot it
# with negative bars for predictors whose Sign is "NEG". Readable labels come
# from `dict`.
p2 <- final_lasso %>%
  fit(data_train) %>%
  extract_fit_parsnip() %>%
  vi(lambda = best_params$penalty) %>%
  left_join(dict, by = c("Variable" = "dirty")) %>%
  mutate(
    Importance = if_else(Sign == "NEG", -Importance, Importance),
    # order bars by the signed importance computed on the line above
    clean = fct_reorder(clean, Importance),
    group = "Did respondent seek restitution?"
  ) %>%
  ggplot(aes(x = Importance, y = clean, fill = Sign)) +
  geom_col() +
  # NOTE(review): fixed limits drop any bar whose |importance| exceeds 0.4
  scale_x_continuous(limits = c(-.4, .4)) +
  labs(y = NULL, x = "Variable importance (LASSO)") +
  scale_fill_brewer(palette = "Set1") +
  facet_wrap(vars(group)) +
  theme(legend.position = "none",
        strip.text = element_text(face = "bold")) +
  geom_vline(xintercept = 0, size = 1, color = "grey80")
p2

# Save p2 explicitly instead of relying on ggsave()'s last_plot() default.
ggsave(here("replication", "figures", "lasso-rest-vip.pdf"),
       plot = p2, device = cairo_pdf)


# Out-of-sample evaluation: fit the finalized workflow on the training split
# and score it on the held-out test split. last_fit() runs once here, and the
# result feeds both the metrics table and the ROC curve.
rest_final_fit <- last_fit(final_lasso, data_split)

auc <- collect_metrics(rest_final_fit)


# ROC curve on the test-set predictions. rest_yn was built with as.factor()
# on "no"/"yes", so "yes" is the second level and is the event.
p2_auc <- rest_final_fit %>%
  collect_predictions() %>%
  roc_curve(rest_yn, .pred_yes, event_level = "second") %>%
  mutate(group = "Outcome: respondent sought restitution") %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_line(size = 1.5, color = "midnightblue") +
  geom_abline(lty = 2, alpha = 0.5, color = "gray50", size = 1.2) +
  # single annotation placed at (0.5, 0.25) reporting the test-set ROC AUC
  ggrepel::geom_label_repel(
    data = tibble(specificity = .5, sensitivity = .25),
    color = "midnightblue",
    label = glue::glue("Out-of-sample AUC: {round(auc$.estimate[auc$.metric == 'roc_auc'], 2)}")
  ) +
  facet_wrap(vars(group)) +
  theme(strip.text = element_text(face = "bold"))
p2_auc

# Save p2_auc explicitly instead of relying on ggsave()'s last_plot() default.
ggsave(here("replication", "figures", "lasso-rest-auc.pdf"),
       plot = p2_auc, device = cairo_pdf)


# combine: stack the two outcome-specific panels vertically with patchwork
# and save each composite. Each composite is passed to ggsave() via `plot =`.
# ggsave()'s default, last_plot(), is not guaranteed to be the composite, for
# example when this script is source()d and nothing is printed.
vip_combined <- p1 / p2
vip_combined

ggsave(here("replication", "figures", "lasso-vip.pdf"),
       plot = vip_combined, device = cairo_pdf)


auc_combined <- p1auc / p2_auc
auc_combined
ggsave(here("replication", "figures", "lasso-auc.pdf"),
       plot = auc_combined, device = cairo_pdf)



